from __future__ import print_function, division
import math

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, models, transforms
import time
import os
from efficientnet.model import EfficientNet
import cv2
import numpy as np
import shutil
from PIL import Image

# ---- runtime parameters ----
# BUG FIX: CUDA_VISIBLE_DEVICES must be exported before the first CUDA
# query — torch.cuda.is_available() initializes the CUDA runtime, after
# which changes to the variable are ignored. The original set it one
# statement too late.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
use_gpu = torch.cuda.is_available()
data_dir = '../clothes_classify'
batch_size = 1
lr = 0.01          # base learning rate (see exp_lr_scheduler)
momentum = 0.9
num_epochs = 60
input_size = 240   # network input side length, in pixels
# class_num = 1
# net_name = 'efficientnet-b1'
# Height that a 2560x1440 frame shrinks to when its width maps to input_size.
resize_size = int(1440 / 2560 * input_size)

def loaddata(data_dir, batch_size, set_name, shuffle):
    """Build the DataLoader for one dataset split.

    Expects images under ``data_dir/<set_name>`` in torchvision ImageFolder
    layout (one sub-directory per class). Returns a dict mapping
    ``set_name`` to its DataLoader, plus the number of samples in the split.
    """
    imagenet_norm = transforms.Normalize([0.485, 0.456, 0.406],
                                         [0.229, 0.224, 0.225])
    pipelines = {
        'train': transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.RandomAffine(degrees=0, translate=(0.05, 0.05)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            imagenet_norm,
        ]),
        'test': transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            imagenet_norm,
        ]),
        'show': transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
        ]),
    }

    split = datasets.ImageFolder(os.path.join(data_dir, set_name),
                                 pipelines[set_name])
    # num_workers=1: one worker process (use 0 when debugging on pure CPU).
    loader = torch.utils.data.DataLoader(split,
                                         batch_size=batch_size,
                                         shuffle=shuffle,
                                         num_workers=1)
    return {set_name: loader}, len(split)



def showUntilExit(title, img):
    """Show *img* in a resizable window until it is closed or 'q' is pressed."""
    cv2.namedWindow(title, cv2.WINDOW_KEEPRATIO)
    cv2.imshow(title, img)

    # Poll once a second; stop when the window disappears or on 'q'.
    while cv2.getWindowProperty(title, cv2.WND_PROP_VISIBLE) >= 1:
        key = cv2.waitKey(1000)
        if (key & 0xFF) == ord("q"):
            cv2.destroyAllWindows()
            break
# NOTE(review): duplicate imports — cv2 and time are already imported at the
# top of the file; multiprocessing appears unused in this chunk. Left in
# place because other parts of the file may rely on them.
import cv2
import multiprocessing
import time


def classOfT(orig, model):
    """Run *model* on an already-batched input tensor.

    Returns (predicted class index of the first sample, raw output tensor).
    """
    logits = model(orig)
    _, top = torch.max(logits.data, 1)
    return top[0], logits.data

def classOf(orig, model, needhotmap = False):
    """Classify a PIL image with *model*.

    Parameters:
        orig: PIL image (RGB); resized/cropped/normalized here.
        model: network producing one logit row per input image.
        needhotmap: when True, additionally build a brute-force occlusion
            sensitivity map (one forward pass per perturbed pixel/channel —
            extremely slow) and save it as 'geeks.jpg'.

    Returns:
        (Tensor, Tensor): the first output row, returned in both slots
        (different callers read different slots).
    """
    tr = transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    inputs = tr(orig)
    inputs = torch.reshape(inputs, (1, 3, input_size, input_size))
    if use_gpu:
        inputs = Variable(inputs.cuda())
    else:
        inputs = Variable(inputs)
    outputs = model(inputs)

    if needhotmap:
        # BUG FIX: `preds` was used below without ever being defined,
        # raising NameError whenever needhotmap=True.
        _, preds = torch.max(outputs.data, 1)
        attmap = orig.copy()
        attpix = [[0.0] * resize_size for _ in range(input_size)]
        origv = outputs.data[0][preds[0]]
        baseh = (input_size - resize_size) // 2

        diffimg = inputs.clone()
        prev = None
        for i in range(input_size):
            for j in range(resize_size):
                diff = 0
                for ch in range(3):
                    # Undo the previous perturbation, then apply the next
                    # one, so diffimg differs from inputs in one cell only.
                    if prev is not None:
                        pi, pj, pch = prev
                        diffimg[0][pch][baseh + pj][pi] -= 0.01
                    prev = (i, j, ch)
                    diffimg[0][ch][baseh + j][i] += 0.01

                    _, ndata = classOfT(diffimg, model)
                    diff += math.fabs(ndata[0][preds[0]] - origv)
                attpix[i][j] = diff
        # BUG FIX: the original did max(attpix) then max() of that, which
        # compares whole rows lexicographically and only scans one row;
        # take the true global maximum instead.
        vmax = max(max(row) for row in attpix)
        if vmax == 0:
            vmax = 1  # flat sensitivity map: avoid division by zero
        for i in range(input_size):
            for j in range(resize_size):
                attmap.putpixel((i, j), (int(attpix[i][j] / vmax * 255), 0, 0))
        attmap.save("geeks.jpg")

    return outputs.data[0], outputs.data[0]



# Load the stream list: the first CSV field of each row is a capture URL.
# BUG FIX: the file handle was opened and never closed; use a context manager.
with open('list.csv', 'r') as f:
    sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]

# NOTE(review): `list` shadows the builtin; kept because the code below
# reads this exact name. Rename (e.g. stream_urls) in a wider refactor.
list = [s.split(',')[0] for s in sources]


def test_model_online(model_reg, model_class, criterion):
    """Continuously poll the video streams in `list`, classify each frame and
    save scored snapshots.

    Parameters:
        model_class: 4-way classifier; only frames whose top class id is > 1
            (presumably 'full'/'notfull' per the class2id map used elsewhere
            in this file — confirm against training labels) get scored.
        model_reg: single-output network whose score is mapped to 0..10.
        criterion: unused; kept for interface symmetry with test_model().

    Never returns — loops over the streams forever.

    BUG FIX: everything after the infinite `while True` loop was unreachable
    dead code that also referenced an undefined `model` variable; it has been
    removed, together with the unused setup (os.listdir('./test'), counters)
    that only fed it.
    """
    model_reg.eval()
    model_class.eval()

    cnt = 0
    # Last saved score per stream: a frame is only written when its score
    # moves by more than one step since the last save.
    lastres = [None for _ in range(len(list))]
    while True:
        ok = False
        for id, url in enumerate(list):
            cnt += 1
            print (url)
            cap = cv2.VideoCapture(url, cv2.CAP_FFMPEG)
            _, img = cap.read()
            if img is not None:
                # Heuristic: fewer than ~700 rows is treated as a bad decode.
                if len(img) > 700:
                    # OpenCV gives BGR; the models expect RGB PIL images.
                    pimg = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
                    classconf, _ = classOf(pimg, model_class)
                    _, predid = torch.max(classconf, 0)
                    if predid > 1:
                        pred, _ = classOf(pimg, model_reg)
                        score = pred[0].item()  # expected roughly in [-1, 1]

                        # Map [-1, 1] -> integer bucket 0..10 (rounded).
                        ssc = math.floor(score*5 + 5 + 0.5)
                        ssc = max(ssc, 0)
                        ssc = min(ssc, 10)
                        if lastres[id] is None or abs(ssc - lastres[id]) > 1:
                            lastres[id] = ssc
                            print(ssc)
                            fname = '{}_{}.jpg'.format(int(time.time()), ssc)
                            # Saved under a directory named after the bucket.
                            cv2.imwrite("./{}/{}".format(ssc, fname), img)

                    ok = True
                else:
                    print("not h264")
            else:
                print('read error')
            cap.release()
        print('sleep...')
        # Back off slightly longer after a successful round.
        if ok: time.sleep(6)
        else:  time.sleep(5)


def test_model_local(model, criterion):
    """Classify every file in ./test with *model* and copy each file into a
    directory named after its predicted class.

    Parameters:
        model: network passed through to classOf().
        criterion: unused; kept for signature symmetry with test_model().

    NOTE(review): classOf() returns the model's raw output row, so the
    int(pred) conversion below only works when the model has exactly one
    output; with a multi-class head it would raise. Presumably intended for
    the regression model — confirm against callers.
    """
    model.eval()
    running_loss = 0.0    # unused in this function
    running_corrects = 0  # unused in this function
    # cont = 0
    outPre = []           # unused in this function
    outLabel = []         # unused in this function




    import os, sys
    path = './test'

    files = os.listdir(path)
    # Numeric ids assigned to each class directory name.
    class2id = {'full': 2, 'notfull': 3, 'dump': 0, 'fold': 1}
    cnt = 0
    from PIL import Image
    for f in files:
        orig = Image.open(path + '/' + f)
        labels = [1]
        # Downscale only; the center crop is applied inside classOf().
        tr = transforms.Compose([
                transforms.Resize(resize_size),
                # transforms.CenterCrop(input_size),
            ]
            )
        inputs = tr(orig)
        # Keep a copy of the last preprocessed input for debugging.
        inputs.save('orig.jpg')
        pred, data = classOf(inputs, model)


        # score = math.fabs(outputs.data[0][0] - outputs.data[0][1])

        # running_corrects += torch.sum(preds == labels.data)

        image = orig
        # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # showUntilExit(str(int(preds[0])), image)
        className = ''
        # Reverse lookup: map the numeric prediction back to its class name.
        for cls in class2id:
            if class2id[cls] == int(pred):
                className = cls
                break
        print(f, data, pred)
        # cv2.imwrite("./{}/{}.png".format(className, f), image)
        shutil.copyfile(path + '/' + f, "./{}/{}".format(className, f))

        # if (preds == 1):
        #     if 0: cv2.imwrite("./{}_pred_{}_score_{:.4f}.png".format(
        #         cont, int(preds[0]), score), image)
        # else:
        #     cv2.imwrite("./{}_pred_{}_score_{:.4f}.png".format(
        #         cont, int(preds[0]), score), image)
            # print(score)

        # cont += 1


def test_model(model, criterion):
    """Evaluate *model* on the 'test' split under data_dir.

    Prints each sample's filename, raw outputs and prediction, then a final
    loss/accuracy summary. Returns the average loss over the split.

    BUG FIX: removed the unused `gpt` confusion-matrix stub — it referenced
    `class_num`, which is commented out at module level, so calling this
    function raised NameError. Also narrowed the bare `except:` clause.
    """
    model.eval()
    running_loss = 0.0
    running_corrects = 0
    cont = 0
    dset_loaders, dset_sizes = loaddata(
        data_dir=data_dir, batch_size=batch_size, set_name='test', shuffle=False)
    for i, data in enumerate(dset_loaders['test']):
        # batch_size is 1 and shuffle=False, so loader batch i corresponds
        # to dataset sample i.
        sample_fname, _ = dset_loaders['test'].dataset.samples[i]
        inputs, labels = data
        labels = torch.squeeze(labels.type(torch.LongTensor))
        inputs, labels = Variable(inputs), Variable(labels)
        outputs = model(inputs)
        _, preds = torch.max(outputs.data, 1)
        print(sample_fname, outputs.data, preds[0])
        try:
            loss = criterion(outputs, labels)
        except Exception:
            # Loss can fail on shape mismatches (0-dim label after squeeze);
            # skip the loss for this sample but still count accuracy.
            loss = None
        if loss is not None: running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
        cont += 1
    print('===============Loss: {:.4f} Acc: {} / {}'.format(running_loss / dset_sizes,
                                            running_corrects, dset_sizes))
    return running_loss / dset_sizes



def exp_lr_scheduler(optimizer, epoch, init_lr=0.01, lr_decay_epoch=10):
    """Decay the learning rate by a factor of 0.8 every *lr_decay_epoch* epochs.

    Sets ``lr = init_lr * 0.8 ** (epoch // lr_decay_epoch)`` on every param
    group and returns the (mutated) optimizer.

    BUG FIX: the original docstring had stray commented-out checkpoint-save
    code pasted into the middle of it and claimed a 0.1 decay factor, while
    the code uses 0.8.
    """
    lr = init_lr * (0.8**(epoch // lr_decay_epoch))
    print('LR is set to {}'.format(lr))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer

def EffModel(net_name, net_weight, class_num):
    """Load a fully serialized model object from the checkpoint *net_weight*.

    net_name and class_num are currently unused: the checkpoint stores the
    whole model (architecture and head included), not just a state dict.

    WARNING: torch.load unpickles arbitrary objects — only load trusted files.
    """
    # On CPU-only machines remap all tensors to the CPU; with a GPU keep the
    # checkpoint's own device placement (map_location=None is the default).
    device_map = None if use_gpu else torch.device('cpu')
    model_ft = torch.load(net_weight, map_location=device_map)
    if use_gpu:
        model_ft = model_ft.cuda()
    return model_ft


if __name__ == '__main__':
    # Load the two checkpoints: a 1-output fullness regressor and a
    # 4-class state classifier.
    model_reg = EffModel('efficientnet-b1', '../clothes_classify/model/efficientnet-b1.pth', 1)
    model_class = EffModel('efficientnet-b1', '../clothes_classify/model/efficientnet-b1-4class.pth', 4)
    # test
    print('-' * 10)
    print('Test Accuracy:')
    # BUG FIX: unconditionally calling .cuda() crashed on CPU-only machines.
    criterion = nn.CrossEntropyLoss()
    if use_gpu:
        criterion = criterion.cuda()
    test_model_online(model_reg, model_class, criterion)
    # BUG FIX: `model_ft` was undefined here (NameError). model_reg is the
    # only model test_model_local can run — its int(pred) needs a one-output
    # network — though this call is unreachable while test_model_online
    # loops forever; confirm the intended model when re-enabling it.
    test_model_local(model_reg, criterion)
    # test_model(model_ft, criterion)
